[Feature] Support Sequence Classification #2704
Conversation
Hey @rabintiwari45, I also happened to write a PR for sequence classification: #2710, but I noticed that your PR is mainly for VLMs? Correct me if I am wrong. My patch does look a bit shorter, but please let me know if I am missing anything. We can probably incorporate both patches somehow.
Hi @pluesclues,
I believe @Etherll has a way of loading any type of model into FastModel as well. There may be an easier way to load the models you want, since this PR is quite extensive, but I was having a bit of trouble finding the code for it.
Hey @rabintiwari45
Thanks a lot, and kudos for these changes. I've added a few minor comments for now.
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Get outputs from the language model part only (ignore vision for sequence classification)
    language_model_outputs = self.llava_next.language_model(
@danielhanchen do you think we should have a forward-function creator that takes language_model as input and performs all of this? This code is similar between both models, so we could do LlavaNextForSequenceClassification.forward = create_forward(self.llava_next.language_model) and MllamaForSequenceClassification.forward = create_forward(self.mllama.language_model).
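To make the suggestion concrete, here is a minimal sketch of the factory pattern being proposed. All names here are illustrative stand-ins (the real version would build the full transformers-style classification forward pass over hidden states), but it shows how one create_forward can serve both model classes:

```python
def create_forward(language_model):
    """Return a forward function bound to the given backbone, so the
    shared sequence-classification logic is written only once."""
    def forward(self, input_ids, **kwargs):
        hidden = language_model(input_ids)   # backbone pass (stand-in)
        return self.score(hidden)            # classification head (stand-in)
    return forward

class DummyBackbone:
    # Stand-in for llava_next.language_model / mllama.language_model
    def __call__(self, input_ids):
        return sum(input_ids)                # fake "hidden state"

class DummyClassifier:
    def __init__(self, backbone):
        # Bind the generated forward as this instance's forward method,
        # mirroring `SomeModel.forward = create_forward(backbone)`.
        self.forward = create_forward(backbone).__get__(self)
    def score(self, hidden):
        return hidden * 2                    # fake classification head

clf = DummyClassifier(DummyBackbone())
print(clf.forward([1, 2, 3]))                # -> 12
```

Both classification wrappers would then differ only in which backbone they hand to the factory.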
Hello @danielhanchen,
When you have a moment, could you please take a look at this PR? I'd really appreciate your feedback.
LGTM.
This PR introduces support for patching AutoModelForSequenceClassification within the FastModel.from_pretrained() interface. It enables the following usage pattern:
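As a rough sketch only (the exact keyword names come from this PR and are not verified here), the enabled usage pattern would look something like:

```python
# Illustrative sketch, not the PR's verbatim example:
from transformers import AutoModelForSequenceClassification
from unsloth import FastModel

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-1B",          # hypothetical checkpoint
    auto_model = AutoModelForSequenceClassification,
    num_labels = 2,                               # classification head size
)
```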
Changes Included
Added patching logic for AutoModelForSequenceClassification to enable compatibility with FastModel.
Updated the finetuner to allow training with sequence classification models.
Modified unsloth_zoo to gracefully handle weights that do not have a quant_state attribute.
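The third change can be sketched as follows; the real fix lives in unsloth_zoo, and the names here are illustrative. The point is to fall back to None instead of raising AttributeError for non-quantized weights, such as a freshly initialized classification head:

```python
class QuantizedWeight:
    # Stand-in for a bitsandbytes 4-bit parameter carrying quant metadata
    quant_state = "nf4-state"

class PlainWeight:
    # e.g. a full-precision score head that has no quant_state at all
    pass

def get_quant_state(weight):
    # Graceful lookup: None for weights without a quant_state attribute,
    # instead of an AttributeError crashing the patching path.
    return getattr(weight, "quant_state", None)

print(get_quant_state(QuantizedWeight()))  # -> nf4-state
print(get_quant_state(PlainWeight()))      # -> None
```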
Notes
While the patch works as intended in current testing, there may be edge cases or integration concerns that require further review.
Please verify whether any additional logic or edge-case handling is needed in related modules.
This PR is linked to #165 in unsloth_zoo.